{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UPvC7G5kUJYG"
   },
   "source": [
    "#  Introduction to Gaussian Processes\n",
    "\n",
    "In the world of cheminformatics and machine learning, models are often trees (random forest, XGBoost, etc.) or artifical neural networks (deep neural networks, graph convolutional networks, etc.). These models are known as \"Frequentist\" models. However, there is another category known as Bayesian models. Today we will be experimenting with a Bayesian model implemented in scikit-learn known as gaussian processes (GP). For a deeper dive on GP, there is a great [tutorial paper](https://arxiv.org/pdf/2009.10862.pdf) on how GP works for regression. There is also an [academic paper](https://doi.org/10.1002/cmdc.200700041) that applies GP to a real world problem.\n",
    "\n",
    "As a short intro, GP allows us to build up our statistical model using an infinite number of Gaussian functions over our n-dimensional space, where n is the number of features. However, we pick these functions based on how well they fit the data we pass it. We end up with a statistical model built from an *ensemble* of Gaussian functions which can actually vary quite a bit. The result is that for points we have trained the model on, the variance in our ensemble should be very low. For test set points close to the training set points, the variance should be higher but still low as the ensemble was picked to predict well in its neighborhood. For points far from the training set points, however, we did not pick our ensemble of Gaussian functions to fit them so we'd expect the variance in our ensemble to be high. In this way, we end up with a statistical model that allows for a natural generation of uncertainty.\n",
    "\n",
    "## Colab\n",
    "\n",
    "This tutorial and the rest in the sequences are designed to be done in Google colab. If you'd like to open this notebook in colab, you can use the following link. \n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepchem/deepchem/blob/master/examples/tutorials/Introduction_to_Gaussian_Processes.ipynb)\n",
    "\n",
    "## Setup\n",
    "\n",
    "The first step is to get DeepChem up and running. We recommend using Google Colab to work through this tutorial series. You'll need to run the following commands to get DeepChem installed on your colab notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tJGVKFHnQieU",
    "outputId": "a8befbcf-fcb8-47ef-f9f3-89a9f1a274b0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: deepchem in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (2.5.0.dev20210319222130)\n",
      "Requirement already satisfied: scikit-learn in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from deepchem) (1.0.2)\n",
      "Requirement already satisfied: numpy in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from deepchem) (1.19.1)\n",
      "Requirement already satisfied: pandas in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from deepchem) (1.3.1)\n",
      "Requirement already satisfied: joblib in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from deepchem) (1.1.0)\n",
      "Requirement already satisfied: scipy in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from deepchem) (1.6.2)\n",
      "Requirement already satisfied: python-dateutil>=2.7.3 in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from pandas->deepchem) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2017.3 in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from pandas->deepchem) (2021.3)\n",
      "Requirement already satisfied: six>=1.5 in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas->deepchem) (1.16.0)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/ozone/miniconda3/envs/mol/lib/python3.7/site-packages (from scikit-learn->deepchem) (2.2.0)\n"
     ]
    }
   ],
   "source": [
    "%pip install --pre deepchem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7VEHj_Y2WWyC"
   },
   "source": [
    "## Gaussian Processes\n",
    "\n",
    "As stated earlier, GP is already implemented in scikit-learn so we will be using DeepChem's scikit-learn wrapper. SklearnModel is a subclass of DeepChem's Model class. It acts as a wrapper around a sklearn.base.BaseEstimator.\n",
    "\n",
    "Here we import deepchem and the GP regressor model from sklearn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "v-h1CDpoWWOE"
   },
   "outputs": [],
   "source": [
    "import deepchem as dc\n",
    "from sklearn.gaussian_process import GaussianProcessRegressor\n",
    "from sklearn.gaussian_process.kernels import RBF, WhiteKernel\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
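  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged aside, the uncertainty behaviour described in the introduction can be read off the standard GP regression equations. The notation here is our own and is not taken from the scikit-learn or DeepChem source: $X$ and $\\mathbf{y}$ are the training inputs and targets, $k$ is the kernel, $K = k(X, X)$, $\\mathbf{k}_\\ast = k(X, x_\\ast)$ for a query point $x_\\ast$, and $\\sigma_n^2$ is an assumed Gaussian noise variance. With a zero-mean prior, the predictive mean and variance are:\n",
    "\n",
    "$$\\mu(x_\\ast) = \\mathbf{k}_\\ast^\\top (K + \\sigma_n^2 I)^{-1} \\mathbf{y}, \\qquad \\sigma^2(x_\\ast) = k(x_\\ast, x_\\ast) - \\mathbf{k}_\\ast^\\top (K + \\sigma_n^2 I)^{-1} \\mathbf{k}_\\ast$$\n",
    "\n",
    "When $x_\\ast$ is far from every training point, $\\mathbf{k}_\\ast$ shrinks toward zero for a kernel such as the RBF, so the variance returns to its prior value $k(x_\\ast, x_\\ast)$. Near training points the subtracted term is large and the variance is small."
   ]
  },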
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nlu44FV8YFMM"
   },
   "source": [
    "## Loading data\n",
    "\n",
    "Next we need a dataset that presents a regression problem. For this tutorial we will be using the BACE dataset from MoleculeNet.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "kbJ3UNt8YWk9"
   },
   "outputs": [],
   "source": [
    "tasks, datasets, transformers = dc.molnet.load_bace_regression(featurizer='ecfp', splitter='random')\n",
    "train_dataset, valid_dataset, test_dataset = datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "C5vePfTzYmHG"
   },
   "source": [
    "I always like to get a close look at what the objects in my code are storing. We see that tasks is a list of tasks that we are trying to predict. The transformer is a NormalizationTransformer that normalizes the outputs (y values) of the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "feTbZu-tYpV-",
    "outputId": "59fa279b-e675-4d81-b4f3-4ee94cc10b42"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The tasks are: ['pIC50']\n",
      "The transformers are: [<deepchem.trans.transformers.NormalizationTransformer object at 0x7fc04401b190>]\n",
      "The transformer normalizes the outputs (y values): True\n"
     ]
    }
   ],
   "source": [
    "print(f'The tasks are: {tasks}')\n",
    "print(f'The transformers are: {transformers}')\n",
    "print(f'The transformer normalizes the outputs (y values): {transformers[0].transform_y}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vLMr7VmEZgRi"
   },
   "source": [
    "Here we see that the data has already been split into a training set, a validation set, and a test set. We will train the model on the training set and test the accuracy of the model on the test set. If we were to do any hyperparameter tuning, we would use the validation set. The split was ~80/10/10 train/valid/test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "L0r7mgS4ZhSd",
    "outputId": "41399618-4193-4b5a-cddb-821881352324"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<DiskDataset X.shape: (1210, 1024), y.shape: (1210, 1), w.shape: (1210, 1), task_names: ['pIC50']>\n",
      "<DiskDataset X.shape: (151, 1024), y.shape: (151, 1), w.shape: (151, 1), ids: ['Fc1ncccc1-c1cc(ccc1)C1(N=C(N)N(C)C1=O)c1cn(nc1)CC(CC)CC'\n",
      " 'S1(=O)(=O)N(c2cc(cc3n(cc(CC1)c23)CC)C(=O)NC(Cc1ccccc1)C(=O)C[NH2+]C1CCOCC1)C'\n",
      " 's1ccnc1-c1cc(ccc1)CC(NC(=O)[C@@H](OC)C)C(O)C[NH2+]C1CC2(Oc3ncc(cc13)CC(C)(C)C)CCC2'\n",
      " ...\n",
      " 'S(=O)(=O)(Nc1cc(cc(c1)C(C)(C)C)C1([NH2+]CC(O)C(NC(=O)C)Cc2cc(F)cc(F)c2)CCCCC1)C'\n",
      " 'O=C1N(C)C(=N[C@]1(c1cc(nc(c1)CC)CC)c1cc(ccc1)-c1cncnc1)N'\n",
      " 'Clc1cc2CC(N=C(NC(Cc3ccccc3)C=3NC(=O)c4c(N=3)cccc4)c2cc1)(C)C'], task_names: ['pIC50']>\n",
      "<DiskDataset X.shape: (152, 1024), y.shape: (152, 1), w.shape: (152, 1), ids: ['Clc1ccc(cc1)CC(NC(=O)C)C(O)C[NH2+]C1CC2(Oc3ncc(cc13)CC(C)(C)C)CCC2'\n",
      " 'Fc1cc(cc(F)c1)CC(NC(=O)c1cc(cc(Oc2ccc(F)cc2)c1)C(=O)N(CCC)CCC)C(O)C[NH2+]Cc1cc(OC)ccc1'\n",
      " 'O1c2c(cc(cc2)C2CCCCC2)C2(N=C(N)N(C)C2=O)CC1(C)C' ...\n",
      " 'S(=O)(=O)(N(C)c1cc(cc(c1)COCC([NH3+])(Cc1ccccc1)C(F)F)C(=O)NC(C)c1ccc(F)cc1)C'\n",
      " 'O1CCCC1CN1C(=O)C(N=C1N)(C1CCCCC1)c1ccccc1'\n",
      " 'Fc1cc(cc(c1)C#C)CC(NC(=O)COC)C(O)C[NH2+]C1CC2(Oc3ncc(cc13)CC(C)(C)C)CCC2'], task_names: ['pIC50']>\n"
     ]
    }
   ],
   "source": [
    "print(train_dataset)\n",
    "print(valid_dataset)\n",
    "print(test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Jdh7Hej8aUMQ"
   },
   "source": [
    "## Using the SklearnModel\n",
    "\n",
    "Here we first create the model using the GaussianProcessRegressor we imported from sklearn. Then we wrap it in DeepChem's SklearnModel. To learn more about the model, you can either read the sklearn API or run help(GaussianProcessRegressor) in a code block.\n",
    "\n",
    "As you see, the values I picked for the parameters seem awfully specific. This is because I needed to do some hyperparameter tuning beforehand to get model that wasn't wildly overfitting the training set. You can learn more about how I tuned the model in the Appendix at the end of this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "f3U2Y0Q-aUqi"
   },
   "outputs": [],
   "source": [
    "output_variance = 7.908735015054668\n",
    "length_scale = 6.452349252677817\n",
    "noise_level = 0.10475507755839343\n",
    "kernel = output_variance**2 * RBF(length_scale=length_scale, length_scale_bounds='fixed') + WhiteKernel(noise_level=noise_level, noise_level_bounds='fixed')\n",
    "alpha = 4.989499481123432e-09\n",
    "\n",
    "sklearn_gpr = GaussianProcessRegressor(kernel=kernel, alpha=alpha)\n",
    "model = dc.models.SklearnModel(sklearn_gpr)"
   ]
  },
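  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, the kernel built in the cell above appears to correspond to the following covariance function. This is a sketch in our own notation: $\\sigma_f$ stands for `output_variance` (which is squared in the code, so it acts as an amplitude), $\\ell$ for `length_scale`, $\\sigma_n^2$ for `noise_level`, and $\\delta_{x x'}$ is 1 only when both arguments are the same training point:\n",
    "\n",
    "$$k(x, x') = \\sigma_f^2 \\exp\\left(-\\frac{\\lVert x - x' \\rVert^2}{2 \\ell^2}\\right) + \\sigma_n^2 \\, \\delta_{x x'}$$\n",
    "\n",
    "The separate `alpha` argument is also added to the diagonal of the training kernel matrix during fitting, so with the values above it acts as a small numerical jitter on top of `noise_level`."
   ]
  },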
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "h7UCSgZ7bsOe"
   },
   "source": [
    "Then we fit our model to the data and see how it performs both on the training set and on the test set. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "8rgzad5ObsTb",
    "outputId": "70a7001f-a6f3-4749-d647-3c5554d7c555"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training set score: {'mean_squared_error': 0.0457129375800123, 'r2_score': 0.9542870624199877}\n",
      "Test set score: {'mean_squared_error': 0.20503945381118496, 'r2_score': 0.7850242035806018}\n"
     ]
    }
   ],
   "source": [
    "model.fit(train_dataset)\n",
    "metric1 = dc.metrics.Metric(dc.metrics.mean_squared_error)\n",
    "metric2 = dc.metrics.Metric(dc.metrics.r2_score)\n",
    "print(f'Training set score: {model.evaluate(train_dataset, [metric1, metric2])}')\n",
    "print(f'Test set score: {model.evaluate(test_dataset, [metric1, metric2])}')"
   ]
  },
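  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the two metrics reported above have the usual definitions. The symbols are our own: $y_i$ are the measured values, $\\hat{y}_i$ the predictions, and $\\bar{y}$ the mean of the measured values over the $N$ points being scored:\n",
    "\n",
    "$$\\mathrm{MSE} = \\frac{1}{N} \\sum_{i=1}^{N} (y_i - \\hat{y}_i)^2, \\qquad R^2 = 1 - \\frac{\\sum_{i=1}^{N} (y_i - \\hat{y}_i)^2}{\\sum_{i=1}^{N} (y_i - \\bar{y})^2}$$\n",
    "\n",
    "Because `evaluate` is called here without the transformers, these scores are presumably computed on the normalized y values. That would explain why the training MSE and $R^2$ printed above sum to 1: if the training targets have unit variance, then $R^2 = 1 - \\mathrm{MSE}$."
   ]
  },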
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b_XkYqGLcdDT"
   },
   "source": [
    "## Analyzing the Results\n",
    "\n",
    "We can also visualize how well the predicted values match up to the measured values. First we need a function that allows us to obtain both the mean predicted value and the standard deviation of the value. This is done by sampling 100 predictions from each set of inputs X and calculating the mean and standard deviation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "lnH5zcI9TF_y"
   },
   "outputs": [],
   "source": [
    "def predict_with_error(dc_model, X, y_transformer):\n",
    "    samples = model.model.sample_y(X, 100)\n",
    "    means = y_transformer.untransform(np.mean(samples, axis=1))\n",
    "    stds = y_transformer.y_stds[0] * np.std(samples, axis=1)\n",
    "\n",
    "    return means, stds"
   ]
  },
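  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reason the mean and the standard deviation are treated differently in `predict_with_error` can be sketched as follows. We assume, as the name suggests, that the `NormalizationTransformer` maps each target as $y_{\\text{norm}} = (y - m) / s$, with $m$ and $s$ the mean and standard deviation it stored; the symbols are our own. Undoing that map on a sampled prediction $Z$ gives:\n",
    "\n",
    "$$y = s Z + m \\quad\\Rightarrow\\quad \\mathbb{E}[y] = s\\,\\mathbb{E}[Z] + m, \\qquad \\mathrm{Std}[y] = s\\,\\mathrm{Std}[Z]$$\n",
    "\n",
    "So the mean needs the full `untransform`, while the standard deviation is only multiplied by `y_stds[0]`."
   ]
  },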
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For our training set, we see a pretty good correlation between the measured values (x-axis) and the predicted values (y-axis). Note that we use the transformer from earlier to untransform our predicted values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 283
    },
    "id": "atm3snwocamM",
    "outputId": "20f4c329-78fb-4384-dfe6-4b27d95c1d77"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7fc0431b45d0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAg30lEQVR4nO3dbYxb53Un8P8Z6qoivbUp2WPDpi1LzQZyYCvSWLO2G2GNWk6spH7JVHYqa9dAtltU/RCkthuoGAPqygKMSIGKrPthsajgtDVgV7ViOVN71foFkdq0SqVgJiOtpFpC4BfJpuxoshKVWkNLHM7ZD+RlSM69l5f3XvLeh/z/AGFmSA7vI8o+fHie85xHVBVERGSegbgHQEREwTCAExEZigGciMhQDOBERIZiACciMtS8bl7smmuu0SVLlnTzkkRExpuYmPiFqg42397VAL5kyRKMj49385JERMYTkVNOtzOFQkRkKAZwIiJDMYATERmKAZyIyFAM4EREhmIAJyJKsLHJPKzBJcud7mMAJyJKqLHJPJ565SgkNW++0/0M4ERECbXjjZMolsqu9zOAExEl1JlC0fP+lgFcRP5SRM6KyLG62xaJyFsi8rPq14URjJWIiOrckE173u9nBv7XAL7cdNsogB+q6mcB/LD6MxERRWjT2mVIWynX+1sGcFX9EYBzTTd/FcDz1e+fBzAScHxERORiZCiHbeuWQ8szl53uD5oDv05VPwKA6tdr3R4oIhtFZFxExqempgJejoioP40M5VCaev+o030dX8RU1Z2qOqyqw4ODc7ohEhFRQEED+M9F5HoAqH49G92QiIjIj6AB/FUAX69+/3UAfxfNcIiIyC8/ZYS7APwrgGUi8qGI/D6A7QC+JCI/A/Cl6s9ERNRFLU/kUdUNLnfdG/FYiIioDdyJSURkKAZwIiJDMYATERmKAZyIyFAM4EREhmIAJyIyFAM4EZGhGMCJiAzFAE5EZCgGcCIiQzGAExEZigGciMhQDOBERIZiACciMhQDOBGRoRjAiYgMxQBORGSolifyEBH1orHJPHa8cRJnCkXckE1j09plGBnKxT2stoSagYvI4yJyTESOi8gTEY2JiKijxibzeOqVo8gXilAA+UIRT71yFGOT+biH1pbAAVxEbgPwBwDuALACwAMi8tmoBkZE1Ck73jiJYqnccFuxVMaON07GNKJgwszAPwfgoKpOq+oMgH8C8DvRDIuIqHPOFIpt3Z5UYQL4MQB3i8jVIpIB8NsAbmp+kIhsFJFxERmfmpoKcTkiomhkM1ZbtydV4ACuqm8D+A6AtwC8DuAIgBmHx+1U1WFVHR4cHAw8UCKiqKi2d3tShapCUdXvAfgeAIjItwF8GMWgiIhaCVNFcqFYauv2pAoVwEXkWlU9KyKLAawD8JvRDIuI+k07AdmuIrEXIu0qEgC+gvgN2TTyDvnuG7LpEH+D7gtbB75HRK4GUALwDVU9H8GYiKjPeAVkAHMCu1cViZ8AvmntsobrAUDaSmHT2mUR/Y26I2wK5T9HNRAi6g1eM+nNY0ex69AHKKsiJYINd96EZ0aWuwbkra8dx4ViCbPV3HS+UMQf7z5c+7mZ3yoSezymb+ThTkwiiozXTHr81Dm8cPB07bFl1drPboH3/PTcnLRb8AaAq9L+q0hGhnLGBexmDOBEFBmv1MbHFz51/J1dhz5wzUm3SyT0UxiFzayIKDJuM+l8oYiyS41eWRWb1i5D2kqFvn7BYcbeyxjAiSgyQao4UiIYGcph27rlyGXTCDOJNq2KJCymUIgosOYFy3tuGcSeifycNIqXDXdWNnDX56RXb9/nmVJJCTB/Xsr4KpKwGMCJKBCnBcs9E3k8vCqH/SemfOe0h29ehLHJPLa+dry2aJmxBmANCEouK5ZXpi1sefBW7HjjJPKFIlIiDc2oTF+c9IsBnIgCcVuw3H9iCpvWLsMTLx329Tzf2n0YAwOCUvlX
wXq6NIsBj1xKYbpUC9JhNvSYjjlwIgrEa8HySZ/BGwDKiobgbZvVSn7ciZ3r7pW2sEFxBk5Enpw25gDAgIhrZUlUPaHcnn/J1ZUA3ittYYNiACciV0557k3fPwKIe3DthoPvVrp29EpPk6CYQiEiR2OTeTy5+/CcFEVpVh1THmFYqbmpEssjCW6/eTjVj/dTNQoDOFGfGpvMY/X2fVg6uhert+9rOA9ybDKPTS8f6Up/7Cvmp7DjkRVYWHeYQjZtYcfXVrjmwO3bm+vHc9k0tq1b3hcLmABTKER9yQ7Q9kw6XyjiiZcOY/zUOTwzshxbXzse+SzbzcXLlUXHLQ/eOifwNvdPsdm140Bv9DQJigGcqA80L0Sev3jJMUDbwdKpiVQnuZX/PTOyHAAcOxgSINrFhYjh4WEdHx/v2vWIaO5CZJLlsmkcGF0T9zASR0QmVHW4+XbmwIl6nFOtdFL1S/lfVJhCIepRzdvT4/Zr8wZwaWbW8zH9Uv4XFQZwoh5Rn+e+Km3h3y/NoOx1+kGXtQre/VT+F5VQKRQReVJEjovIMRHZJSILohoYEfln57nzhSIUQKFYSlTwbkUAPLyqf6tJggocwEUkB+CPAAyr6m0AUgAejWpgROSfSXluJwpg/4mpuIdhnLAplHkA0iJSApABcCb8kIiomddBwQAiOY4sblzAbF/gAK6qeRH5MwCnARQBvKmqbzY/TkQ2AtgIAIsXLw56OaK+5XVQMIDEdt7LWANQSMMnAyslmCmrY7MrLmC2L3AAF5GFAL4KYCmAAoDvi8hjqvpC/eNUdSeAnUClDjz4UIn6k1vL1K2vHcenpdlEpk7SVgrfXlfZbOPUybC5Lp0LmMGESaF8EcB7qjoFACLyCoAvAHjB87eIqC1uqYWklAeKAFrt3V1WRa4pxeO2MOmVEiJ/wgTw0wDuEpEMKimUewFwmyVRxNxapibBAIDv/u5K1+Drlrvv5/4lUQpchaKqhwC8DOCnAI5Wn2tnROMioiqnlqlJIAJ8d7138K4vbbRz9/VdDymcUFUoqroFwJaIxkLU1zaPHXVs2jQylMP4qXO1+xJDvc+d9DrujLPvaHAnJlGXOaUVmtumllUbfn7x4OnIjimLSquqkX4/7qwbGMCJusitJPDTGedKkr85dBqq0Z0xGRUrJS2rRvr9uLNuYDdCoi5ySyu4ZUZmYw7eGWsAz65fiWz6V6flLMxY2PHIipZpkH4/7qwbOAMn6iLT0gfTpVnseOMknn5o7mk5rdiPZ7lg5zCAE3VRkksC3bidluMHywU7iwGcKKBW/UmabR47io8umBW8baweSSYGcKIAvPqTOAW5zWNHHQ/nNUlz+qfdNzCKHgM4UQBeNc72/fWBbdehD+IYZqTqq0fafQOjzmAVClEAbouR+UIRT750eM7uw0RtwAmguXqk1RsYdQcDOFEAXrXMzaE6id0C25HLprFt3fKGmTU36SQDAzhRAEntTxK1AakE5R1vnGzoYeL2BsZNOt3FAE4UwMhQDg+vyiElEvdQQrNSgsfuWgwrNffvYm8kam5ExU06ycAAThTA2GQeeybyxue2AaBUVuz9vx9h/X+6CblsGgI4vjHV57hHhnLYtm557fFOaRbqPFahUGziKEMLes3m35u+PGN8brve+ekS9kzka0F46ehex8fV57idNumwtLC7GMApFp0uQ3MKJAACXdNprL2ofrNOkEZULC3sPqZQKBadLENzO0jg6VePB7qm01h7lR20g+S4WVrYfZyBU8d4fZzuZBmaWyBxC8KtdhiaPuPOtfF3sHPfQRpRsbSw+xjAqSNafZzuZK/odgPGgAiWju7FDdk07rllEHsm8g3jFiSvH7dfuWwaB0bXAACWju5t+feoX5RttxEV+393X+AUiogsE5HDdX9+KSJPRDg2Mlirj9OdLENrN2CUVWuplhcPnp4zblODtzXQeOiCn9clFyLYsrSw+8IcanxSVVeq6koAqwBMA/hBVAMj
s7X6ON3JMrQwm2xMDdaOmioB/bwu99wyGPhyLC3svqhSKPcCeEdVT0X0fGQ4Px+nO9Ur2il/O315BuenS5FfK8lKZW1oAVv/urjlxPefmAp1Tfb/7q6oqlAeBbDL6Q4R2Sgi4yIyPjUV7j8OMkecH6edFk+3PHhrX2x9b9b8SWhkKIcDo2uaJ+euj6dkCx3ARWQ+gIcAfN/pflXdqarDqjo8OBj84xmZJczH6bHJPFZv34elo3uxevu+hh4cfn7XqYQQALatWx7wb5NcdiB229Lfbs8SLjiaJYoUylcA/FRVfx7Bc1EPCfJxOuxmEK/F0wOja/D0q8dRKJqXShmQSl8SAMimrTlnVDa/boD3J55Na5e19XhKpigC+Aa4pE+ov0SxjdorADs9l9+abTs1YGrvqVmtBFi3TzEjQzmMnzqHXYc+QFkVKRE8vMr9DZQHDveGUAFcRDIAvgTgD6MZDpkqqm3U7WwGcbqmW822nRooGLyQWSyV8a3dRwDMfU2bm2uVVbFnIo/hmxd5BnEGbLOFyoGr6rSqXq2qF6IaEJkpqm3UfnKzdo78iZcOO9ZsN0+y61MDmflmL2SWVRvautq4jb0/sRcKRSKqbdStqlfqFyndKNCwePrwqhx2vHESS0b34uJl83uaOAVmt9fD9DYA5I1b6SkSUW2jbpWb9dNYqn77uNPiXi9ofmNMiTj2Ju+FAyfIHQM4RSLKqgav3GyrGb2VEly8NFPrbVKYvmxk8E6JYMOdN9UWJZs1vzG6HSzRCwdOkDsGcIqEHXDry/QWWNFk6OorTUQAr5hULmvt+klLH2TTlq8SRisl2PHICowM5TB88yJfb4xuHQfD9Dah5GMOnCJ1aWa29v356ZLjgls7mjfmzLaYUM563x2rB1Zc33I36MKMVQvegP8NUWwk1Z9Eu/gRa3h4WMfHx7t2Pequ1dv3uc4C7Zx0O8Ym8/jW7iM9kwbIVdvV1tdqb7jzJjwzEs0OUR5n1rtEZEJVh5tvZwqFIhNlQ3975u03eK/+zCIceOdc29fppnyh2HatdjtY191/mEKhyETZX6PdY8x+nPDgbWOtNkWJAZwi024e1qtpVbsLkCYnWdgBkIJiCoUi005/jVZb700+xqxd7ABIQTGAU6T85mFbbf32Ct5+y/FMwEoRCoMBnGLh1TXQKyecEklc8E5bA1hgpVCYLmHAZUdkvZQIZlVZKUKhMYBTV41N5vH0q8dd778qbXnmv5NSUiiAYwButXXfqyUsUbsYwA1lSs1v/TizGQuffDqDksdunIuXZ7o4uuDe236/4+3N6wDZjAVV4EKxlOh/JzITA7iBouq93WnN4/RzqHCpnIwZthdB5e/GPtsUN5YRGsiU3s/t1nKbQoGOv9ZhzgWl/sEZuIGi3PHYSe2Mx0oJZspqTOlgJ19rUz5hUfw4AzeQKSeK+x3PFfNTgJpV993J19qUT1gUP87ADZTEE8XrFyuvSlsQqeS8vTbkCCpHnJl2Sk6nX2tTPmFR/MIeapwF8ByA21D5//S/q+q/RjAu8pC0E8WbP/LX12l7zaoVMCZ4d7N2O6rTjaj3hWonKyLPA/hnVX1OROYDyKhqwe3xbCfbe0xv+Wq3eN1/Ygr5QhEDMrfnuJUSXDF/Hi4US7VPF4XpzpUFOtWSs368v0XeTlZErgRwN4D/BgCqehnA5aDPR+aw0yX5QtHoniV+zs7MWAMo1Z3yU//polOLi0n7hEXJFSaF8hsApgD8lYisADAB4HFVvVj/IBHZCGAjACxevDjE5fpLtzfq+L1ec6AzNXhbKWnIY7uVPF6aUc9PF/biYtT/NqwlJz/CBPB5AG4H8E1VPSQifw5gFMCf1j9IVXcC2AlUUighrtc3ul1G5nU9ALXZttvJ56ZZmLGw5cFbG15LtwVCP39fLi5SXMIE8A8BfKiqh6o/v4xKAKeQvMrIOhHA3a639bXj+LQ0W7vP5ODd6lg3t4VDP29aXFykuASuA1fVjwF8
ICL259B7AfxbJKPqc90uI3N73vPTpZ7YSWkNSMuyP7fDKDbceZPnQcRxl29SfwtbB/5NAC9WK1DeBfB74YdE3S4jy2YsX31KjCWtH+K1cDh886I5Ne6drEIh8oun0idQt8vIVm59M3E9tqPWKoVClGQ8ld4g3S4j6/XgDXChkXoTA3hCdbOMrFeqSwC41qVzoZF6EZtZUc8Eb6ASvJtT3lxopF7FAE5IifMqn0ilgsM0ikrOW6pfuQWdehVTKOQ6A1cFSgmdnQ8I8GvzBlAszc65jwuW1C84AyfjPLt+Jd7ddj+2rfu8Y+020yXULzgDJ6MszFi1dAibPlG/YwAnxxaqSSQAtjx4a8NtbPpE/YwB3CCd6lDolktOGgXPhCSqxwBuiE50KLTfEJIWvLNpy3FzUY613EQNuIhpCLeOgU+8dBirt+/D2GS+reez3xCceq50kpUSPLt+pWswzmXTePqhW7k4SeQDZ+CG8NoKHmQ27naAQafteGRFbYxuBzNzcZLIHwZwQ7h1KLS12y+8271BBgB8d/1K3xUkXJwkao0B3BCb1i5zPLOxnltQbl78vOeWQQx0uf/Jf7lr8ZyAzCBNFA4DuCHqZ6xuM3G7YVPDocNS2VFpyxeKeOHg6Y6Pt9n+E1NdvyZRr+MipkFGhnI4MLoGz65f6brI17w4mZSd8N1eLCXqBwzgBhoZymHbuuWODZviWpxsRYC2K2WIyBsDuKHGT53Dxxc+hQL4+MKnGD91DkByDy5QVNI/RBSdUDlwEXkfwL8DKAOYcTryh6K3eexoQx67rIoXDp7Ge1OfdH1xsh1JfXMhMlUUi5j3qOovIngectFcReIWCA+8c67LI2tkn4bjdsIPT8UhiharUBJu89hRvHjwdO2YsCQvBtpjdAre3ElJFL2wAVwBvCkiCuAvVHVn8wNEZCOAjQCwePHikJczT5gGVGOT+VhK/qKSEsGsKndSEnVI2AC+WlXPiMi1AN4SkROq+qP6B1SD+k4AGB4eTmZytkPCNqDa+trxjo6v02ZV8d72++MeBlHPChXAVfVM9etZEfkBgDsA/Mj7t/qHWwMqry3v9TN209/tmPMm6qzAZYQicoWI/Lr9PYD7AByLamC9wG2x0WvLu70JJ6nB2/UA5KafmfMm6rwwdeDXAfgXETkC4CcA9qrq69EMqze4zUDdbk/qJpx68+cJrFRjuE5bKfzXuxbzJHiiLgucQlHVdwGsiHAsPcepAZXXzNSEOuliaRbWgGBhxkJhusQFSqIYsYywg1q1TG2uULnK5SSapCnNKjLz52Hyf9wX91CI+hoDeIe5tUx1qlCxUmLMAcMmfFog6nUM4B3mVgfulO8uleOL3LkWB0Y0Y4UJUfwYwEPy2qjjVQeepBmsnZf36jVeTwBWmBAlALsRhtBc9mcHaLttqlcdeGZ+yuEZ42FXjGxau2xOn3EnCv9nbxJR5zCAh+AVoAH3WXa+UMTFy8koF8xl0w3nUNb3GXer+XY7UZ6IuosplBBabdRpdRBx3JxKGusXXZtTQG6/Q0Tx4Aw8hFYbdZZcnayZqgBYmLF8b7bxOvmHiOLHGXgbnE533zORn5NGuXhpBmOTefz43Xj7czf7n+tXth18eXI8UXJxBu6T04Llnok8Hl6Vw8KM1fDYQrGEp145mpgDhYHKzJuBmKi3MID75LZguf/EFDLz536QSVJPk7SVwpYHb417GEQUMaZQfGq3s2DcctWj19irhKh3MYD75FZRYi9YJqnaZGHGwoHRNXEPg4g6jCkUn5w2udgldW73dVrGcv7nu//z13f82kQUP87AfRoZymH81DnsOvQByqpIieDhVY0VGk+/erzWTXCBNdDxPPh0adbx9v0npjp6XSJKBs7AfRqbzGPPRL524npZFXsm8rVt8wBwaeZXAfX8dLi2sM57IP1Jal6eiKLFAO5Tq23zT796PNIZ9xc+s6i2gcYpVZK2UnPKF23sFEjUH5hC8cmrr8mS0b2R
X+/9/1dsWIh06noIgFvdifpY6AAuIikA4wDyqvpA+CElU7f7mjS/YXjtiHRrZ0tEvS2KGfjjAN4GcGUEz5Uo9bPetEvFR6f4TYNwqztR/woVlUTkRgD3A3gumuEkR/PWebeKj05gGoSI/Ag7rXwWwJ8AcI1uIrJRRMZFZHxqypzyNqdFy05KibDjHxG1JXAKRUQeAHBWVSdE5LfcHqeqOwHsBIDh4eEEtXfy1u1SvFlVvLf9/q5ek4jMFmYGvhrAQyLyPoC/BbBGRF6IZFQJ0O1SPJb+EVG7AgdwVX1KVW9U1SUAHgWwT1Ufi2xkXTI2mcfq7fuwdHQvVm/fV9uY4/d8SJuVElgD/rbfND9KANxzy6DvaxERAX1eB+51ajwA+IzHWJixau1a7ZPdUyIoq0JQOQTYlrZSuH3xVfjxO+dqtyuAPRN5DN+8iLlvIvJNtIunDgwPD+v4+HjXrtfK6u37HGu7F2YsfHJpBqVy69fm2Ran3DhtwLGDfLNcNs0ugkQ0h4hMqOpw8+19PQN3W6j028ek/kR3N0512k++dLit8RAROenrXihhFg6tlODipZk5ufMw1+VCJhG1o68DeLsLlTYBAK2cfWmfj/nUK0d9B3Gv3uJERH71XApl89jRhp7dG+68Cc+MLHd8rJ3a+NbuI7U2sX4ogNJs4+PtzoR+FiHtx7CHCRGF0VMBfPPYUbxw8HTt57Jq7efhmxc5BsyRoZxrTrpd7eSw2cOEiMLqqSqUzzz1944zaQGwwEo1bI23y/ty2TSmL8+EPoABYBUJEXVGX1ShuKVBFJjT18R+ZL5QhDUgsFLiq2zQDXPYRNRtPbWImZJgB5GVZhVXzJ9XOwEnl03jsbsWw0r5e76UCBtQEVHXGTcDd9oYYwfODXfe1JADb8eFYgmHt9zXcFt93jybsXBhujSn7aKVEux4ZAWDNxF1nVE58Oat7zZ7K/vIUA6f+9N/QDFA724/+euxyXzDyfP11yUi6hS3HLhRAdxt6ztQyUFvW1cpF3QK8l7qf5elfUSUNG4B3KgcuFeZXn0dth2MvdjZbfsABQANJ/C0uzmHiKjbjMqBtzpY2D4hPpu2kE1btVSHE7uE0E6brN6+b86svZ3NOURE3ZaIGbhbT+5mfre+F4ol/PLTUsv+3PUzerfZPRtMEVFSxR7Amw8P9kpd2OmRbNpq+byzCvyHBZXSQDcK1N4w2GCKiEwTewB3OjzYTl04GRnK4fCW+7D6M4taPndhuoQDo2vw7PqVrjN3+w3jnlsG2WCKiIwSewAPkrrYPHYUB9451/K57dmzPXN3m40XS2XsPzFVewxPhyciE8S+iJnNWI59SLIZ9zTJi4dab9axBqRh9mw3j1o6uhdOhZNnCkU2mCIiowSegYvIAhH5iYgcEZHjIrI1yPO4laF7lae3Kl0XAXZ8zXl3JHPdRNQrwszALwFYo6qfiIgF4F9E5B9U9aDbLxSmS1i9fV/DRpkLLqV+brf7onCdSW9au2zORh/muonIRIEDuFa2cH5S/dGq/vGcG+cLRcxUc9v24uFVLvXaXjPijDWAaY/t8l6/y8MUiKhXhMqBi0gKwASA/wjgf6nqIYfHbASwEQBSVw423FcslbHAGkC6qVd3qxnxt9d9Hn+8+zBmHd4u/Mymmesmol4QqgpFVcuquhLAjQDuEJHbHB6zU1WHVXU4lblqznMUpkttV3+MDOXw3d9dWasqsdvIsnKEiPpJJFUoqloQkX8E8GUAx9r53Ruy6UAzYs6iiajfBQ7gIjIIoFQN3mkAXwTwHa/fGWg6cMFOd3j1+CYiImdhZuDXA3i+mgcfALBbVf+P1y/ksmlcl003BGqgsf2rvbgJuFeSEBFRAvqBu/X45gHBREQVie0Hzi6ARETBxB7AuTOSiCiY2AO4U49v7owkImot9mZW3BlJRBRM7AEcYE03EVEQsadQiIgoGAZwIiJDMYATERmqqxt5RGQKwKmIn/YaAL+I+Dk7geOM
jgljBDjOqPXzOG9W1cHmG7sawDtBRMaddiglDccZHRPGCHCcUeM452IKhYjIUAzgRESG6oUAvjPuAfjEcUbHhDECHGfUOM4mxufAiYj6VS/MwImI+hIDOBGRoYwM4CKyQER+IiJHROS4iGyNe0xeRCQlIpMi4nliUZxE5H0ROSoih0VkvPVvxENEsiLysoicEJG3ReQ34x5TMxFZVn0d7T+/FJEn4h6XExF5svr/0DER2SUiC+IeUzMRebw6vuNJex1F5C9F5KyIHKu7bZGIvCUiP6t+Xdip6xsZwAFcArBGVVcAWAngyyJyV7xD8vQ4gLfjHoQP96jqyoTX2v45gNdV9RYAK5DA11VVT1Zfx5UAVgGYBvCDeEc1l4jkAPwRgGFVvQ1ACsCj8Y6qkYjcBuAPANyByr/3AyLy2XhH1eCvUTnMvd4ogB+q6mcB/LD6c0cYGcC14pPqj1b1TyJXY0XkRgD3A3gu7rGYTkSuBHA3gO8BgKpeVtVCrINq7V4A76hq1DuQozIPQFpE5gHIADgT83iafQ7AQVWdVtUZAP8E4HdiHlONqv4IwLmmm78K4Pnq988DGOnU9Y0M4EAtLXEYwFkAb6nqoZiH5OZZAH8CYDbmcbSiAN4UkQkR2Rj3YFz8BoApAH9VTUk9JyJXxD2oFh4FsCvuQThR1TyAPwNwGsBHAC6o6pvxjmqOYwDuFpGrRSQD4LcB3BTzmFq5TlU/AoDq12s7dSFjA7iqlqsfUW8EcEf1o1aiiMgDAM6q6kTcY/FhtareDuArAL4hInfHPSAH8wDcDuB/q+oQgIvo4MfTsERkPoCHAHw/7rE4qeZmvwpgKYAbAFwhIo/FO6pGqvo2gO8AeAvA6wCOAJiJdVAJYmwAt1U/Qv8j5uahkmA1gIdE5H0AfwtgjYi8EO+QnKnqmerXs6jka++Id0SOPgTwYd2nrZdRCehJ9RUAP1XVn8c9EBdfBPCeqk6pagnAKwC+EPOY5lDV76nq7ap6Nyrpip/FPaYWfi4i1wNA9evZTl3IyAAuIoMikq1+n0blP8QTsQ7Kgao+pao3quoSVD5K71PVRM1wAEBErhCRX7e/B3AfKh9dE0VVPwbwgYjYB6beC+DfYhxSKxuQ0PRJ1WkAd4lIRkQEldczcYvCInJt9etiAOuQ7NcUAF4F8PXq918H8HedulAijlQL4HoAz4tICpU3od2qmtgSPQNcB+AHlf+HMQ/A36jq6/EOydU3AbxYTU+8C+D3Yh6Po2q+9ksA/jDusbhR1UMi8jKAn6KSlphEMrer7xGRqwGUAHxDVc/HPSCbiOwC8FsArhGRDwFsAbAdwG4R+X1U3iS/1rHrcys9EZGZjEyhEBERAzgRkbEYwImIDMUATkRkKAZwIiJDMYATERmKAZyIyFD/H+wXDFF+2MgWAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "y_meas_train = transformers[0].untransform(train_dataset.y)\n",
    "y_pred_train, y_pred_train_stds = predict_with_error(model, train_dataset.X, transformers[0])\n",
    "\n",
    "plt.xlim([2.5, 10.5])\n",
    "plt.ylim([2.5, 10.5])\n",
    "plt.scatter(y_meas_train, y_pred_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EOtLvq02A2bb"
   },
   "source": [
    "We now do the same for our test set. We see a fairly good correlation! However, it is certainly not as tight. This is reflected in the difference between the R2 scores calculated above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "uswXIr2vqq-W"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7fc04023b590>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAd7UlEQVR4nO3dbWxc1ZkH8P/f4wHslOIAhgVDCF2hUAFKTCxKGy0qoW3aQsGFvoAWCVVV0w9VC4jNKkiVAGklXKVVt5+qjegLEjSFEnBp2eVFhLZatKG1cVJISYRaIDChxBUYBDHgOM9+8IwZ37n3zn2buffM/H9S5GQYz308Ic8985znnEMzg4iIuKcn7wBERCQZJXAREUcpgYuIOEoJXETEUUrgIiKO6m3nxU488URbuXJlOy8pIuK8ycnJf5jZoPfxtibwlStXYmJiop2XFBFxHsmX/B5XCUVExFFK4CIijlICFxFxlBK4iIijlMBFRBylBC4i4iglcBERRymBi4g4SglcRMRRTRM4yZ+SPEjy2brHjif5GMnnq1+XtzZMERHxijIC/zmAz3oe2wzgcTM7C8Dj1T+LiEgbNU3gZvYHAK97Hr4CwJ3V398JYDTbsEREpJmkNfCTzexVAKh+PSnoiSQ3kpwgOTE9PZ3wciIi4tXySUwz22pmI2Y2MjjYsBuiiIgklDSBv0byFACofj2YXUgiIhJF0gT+IIDrqr+/DsCvswlHRESiitJGuA3A/wFYRfIVkl8HMAbg0ySfB/Dp6p9FRKSNmp7IY2bXBPynSzKORUREYtBKTBERRymBi4g4SglcRMRRSuAiIo5SAhcRcZQSuIiIo5TARUQcpQQuIuIoJXAREUcpgYuIOEoJXETEUUrgIiKOUgIXEXGUEriIiKOUwEVEHKUELiLiKCVwERFHKYGLiDgqVQIneT3JZ0nuIXlDRjGJiEgEiRM4yXMBfAPABQBWA7iM5FlZBSYiIuHSjMA/CmCnmR0ys8MAfg/gi9mEJSIizaRJ4M8CuIjkCST7AXwewOneJ5HcSHKC5MT09HSKy4mISL3ECdzMngPwPQCPAXgYwG4Ah32et9XMRsxsZHBwMHGgIiKyVKpJTDP7iZmdb2YXAXgdwPPZhCUiIs30pvlmkieZ2UGSKwBcCeDj2YQlIiLNpErgALaTPAHAHIBvmdkbGcQkIiIRpErgZvYvWQUiIiLxaCWmiIijlMBFRBylBC4i4iglcBERRymBi4g4SglcRMRRSuAiIo5SAhcRcZQSuIiIo5TARUQcpQQuIuIoJXAREUcpgYuIOCrtdrIi4rjxqQq2PLIPB2ZmcepAHzZtWIXR4aG8w5IIlMBFutj4VAU33/8MZufmAQCVmVncfP8zAKAk7gCVUES62JZH9i0m75rZuXlseWRfThFJHBqBi3SxAzOzsR7Pgko22VECF+lipw70oeKTrE8d6GvJ9dKUbJT4G6UqoZC8keQeks+S3EbymKwCE5HW27RhFfrKpSWP9ZVL2LRhle/zx6cqWDe2A2dufgjrxnZgfKoS63pJSza1xF+ZmYXhg8Qf9/qdJnECJzkE4DsARszsXAAlAFdnFZhIO6VNTK4aHR7CVWuHUCIBACUSV60d8h3ZZpFEk5ZsVKv3l3YSsxdAH8leAP0ADqQPSaS9OmV0l+QmND5VwfbJCubNAADzZrhr536sue3Rhu/PIokGlWaalWzyqNW7IHECN7MKgO8D2A/gVQBvmtmj3ueR3EhyguTE9PR08khFWqQTRndJb0J+PzsAzMzONXx/Fkk0bsmmJmni73RpSijLAVwB4EwApwJYRvJa7/PMbKuZjZjZyODgYPJIRVqkE0Z3SW9CYT/j7Nw8bn1wz+KovqdaZvGKk0RHh4dw+5XnYWigDwQwNNCH2688r+lkZNLE3+nSdKF8CsALZjYNACTvB/AJAHdlEZhIu7S7E6MVkt6Egn72mpnZOczMzgHAYpmlXpIk
OjrsX2Nv9j0A1IXikSaB7wdwIcl+ALMALgEwkUlUIm20acOqJa1tQP6ju7gtc1FvQt7XvfjsQWyfrPiWUYKUSBwxC4yrWexJ2wGTJP5OlziBm9lTJO8D8DSAwwCmAGzNKjCRdina6C5Jr3SUm5Df626frOCqtUN46M+v4o1Dc5HiO2KGF8YuTRS7lu5ni+bzsahVRkZGbGJCg3SRMOvGdviOpocG+vDk5vUNj9dGtJWZWZRIzJthyOcm1Ox1vSPjg2/NYu5IY3wDfWXsuuUziWKP+7PJApKTZjbifVwrMUUKJk492zuinTdbHHl7R7TNXre+RDE+VcEN9+zyfX7AXGaka3TChHGRaDMrkYKJ0zJ364N7fLtPbrp3d0MLYZzXDetemQkptRzXVw59XO2A2VICF2mTqAttorbMjU9VFjtEvObNGvq447TihY2Ie8jA2ING57XH1Q6YLSVwkTbwW2hz4z278N3xZxoSO4BIvdLNery9feBxls2HjYj9bg41QaPz2uNJ+8DFn2rgIm3gt9DGANy1cz/u+dPLmJtfaCaodWXcfuV5TSf1otSNKzOzGJ+qLHaAeJfNb5+sYOSM4xsSqF9XS73azcH7fVHaGdUOmB2NwEXaIGyxTC151wStoPSO1Af6/evNXrXRctBqzRvu2YV/vvm/8d3xZxYfrx8pB/G7gahE0l5K4CItNj5VQUjjhi9vcvQrwbxxaA49EV64dkMIG7HXNrHyJvEnN68PTOJ+ZRaVSNpLJRSRFtvyyD7EXW3hTY5Bm04dsYW+7Ddn53BcXzlwUrPW2x32SQAAtj31Mv5j9Lwlj8VdqaoSSfsogYu0QP2imGbJu9RDzB/54Fl+yTFs9Lzs6N7FhTVBC2Vqq0vD6tqA/34nRVupKh9QAhdpIsreHfXPOa6vjHfeP9xQ2w5y7NG9WHZ0b+jeIT3VFZZ+6pN72Gi5PhEHjcRLAX2AGlUXk2rgIiGi7LPtfc7M7Fzk5A0Ab87OYdOGVTh1oA8HZmax5ZF9i69fe+2g5A00dniE1aBrde1rL1zh+1rXfOz0yHFL/jQCFwkRts92/Yg2zm5+Xsf1lQM3eGr22vXlFu8nhR9+dU3gqLlW59721MuYN0OJxDUfO72h/i3FpgQuEiLK3h1R9/FY3l/Gu3NHGsobJAJvEmGvXSIXR9dJdvkbOeN4PLF3GgdmZvFPxx2DkTOOb3iOToIvNiVwkRBRFqZE6e7oK5dwyxfOAdA4GXhjwKZRzTpHjpiFfgqo/6TQbB9wv4SvrV+LTzVwkRBRFqb4PaeHH+z/Ub9kvVaDfmHsUjy5eT1Gh4dCN3gKWwBT/31hnxT86vh379zf9Ai2TjgrtNMpgYtzkpy+nsbRvR/8M1neX25YmOKdOFzeX0aJRG3esbZkPenmVX6Ldco9XJLcw24CQcv4/UQpDWnr1+JQCUWc0s6P9d5rAcC71RMO/GrDtb1L1o3taDjdpn7kGlRT9nt83dgOHPHJth86pnfJz3vx2YO4a+f+huddfPYg7vZ5PEj9drCdcFZop1MCF6dE6Qpp9bVufXAP3jt8JPAmEjRCrczMYtN9uxs2rqp9n1/8Qa/1xqE5rBvbsZjw33nvsO/zntg7HalGX1PfBl7Es0JlqcQlFJKrSO6q+/UWyRsyjE2kQVAiapagkpRdgpLnzOxc4CEKZ25+CD0hR9ZE3biqJmy0W1/TDltC71eiCfLGobnF90b7mhRfmkON9wFYAwAkSwAqAB7IJiwRf6WAFYlBKwiB5GWXOCNXAEu2aY0jrKYcZfl7mOP6yr4lmkPvHw48xNj7qUAJu7iymsS8BMBfzeyljF5PxFdQcgxLmkGlkNt+syf0WkGTi8sjbuMaVQ8Z+sngmPIH/0wHAo4sC1K7r3m7X275wjmBo/L6TxPtmCSW5LJK4FcD2Ob3H0huJDlBcmJ6ejqjy0m3CtraNO6+1cDScoGfoBJCWPJLYt7Md5l+7ZND/Uj5vcM+x8SH8Dsh
J2hv8CgxSbGkTuAkjwJwOYBf+f13M9tqZiNmNjI4OJj2ctLlkhwYEFZHbtbTXBu5/vCrawAAN96zC1se2Yer1g4tJvag8k3Y6e1BvN0qfp8c4rys92ev7wlPEpMUSxYj8M8BeNrMXsvgtURCJZlYC0vuUXqa/RbCbJ+sYNOGVXhh7FL84CurfUfk3qpOOeK/tlpMQbEZFvrA6/UAKJeWPuZ3Y0u6b4t6v4spizbCaxBQPhFphbgTa6PDQ7j1wT2+nRpRepqbtS56JwlJ+PZuz0WsftRiCppEHar2iXv7xutjCNq3JCwREwjctla938WUKoGT7AfwaQDfzCYckda49fJzEvc0R1mRWEvk41MV3BCwt0kUxMIIf93Yjob9SupjDrqJNbuxhd0Unty83nfxknq/iytVAjezQwBOyCgWkZZJc6pMsxWJ9asyw3rAmyE+WOJeK9NctXZoccfALHYDbLY4R6fvuIUWs2c1jZGREZuYmGjb9USyEDQqvf3Khb2zo/Zpl3v8yyj95R4sX3Z06Mg4S9oi1j0kJ81sxPu4ltKLNNFsr5Ioybuv3IPZgCL4obkjONTGjaO0OKdzKIGLRBB3r5J65R7isN+sZgTHxVy4I91FCVz0kboqyfsQVB8vkThi1nTZejNvvbuw2Kgb/z6kOe0H3uWiHNrbDZK+D0ELi37wldWLy9b9VkNGdcSAWx8MX/Jf0+590iV/SuBdrtNPXYma1JK+D1EWFqXtoQ7aabCebsTdSSWULtfJp6747UK46Ve7cdtv9mDm0NySMkma96HZpKBf6165hwAbt5dNqp37pEtxKIF3OddOXYlapx6fquCme3c3rCqcO2KL9ej6bWVb+T4EdbF4H5s59D7eeb+xo6V+98Ogn7+Tb8QSTAm8y7l06krUfb1rz4uyL3dtlNrsfUg70Rtl5eT4VGXJiT3Awv4mtdPsw35+127Ekg3VwLucS6euRK1Tx92wqTIzu/g9tZ0F69+HdtWXR4eHsOVLq5f8XWz50uolI/ignz/JLo3iPo3AxZmFHUFboHrLBHHLBrX9R4CFfbDr9xsBsqkvRx3Bh/1dhJVJtAS+OymBixPGpypL9gqp5y0TxDrE1+c1vck5bX056ZFuXs3KJK7ciCU7KqGIE7Y8ss83eQMf7N5XK2lEPcS3v9wT+Jr1yTmojhy1vpxVq6bKJOKlBC5OaDbara9Le+v6QSfmHArZoLs+OadNnFl1iLg0XyHtoRKKOCFKWcR7yEItsa3c/FCsa3mT8+jwECZeeh3bnnoZ82YokbhqbfRyRZYdIiqTSD2NwCWRdi/b3rRhVcORYX68iTJJXN5R7fhUBdsnK4ttifNm2D5ZifzaKn1Iq2gELrFlNSkX5Tq1roqB/jLmI6xa9JZL4taZhwb6Gn6GtF0o6hCRVlECl9jasWzbe5OIupufd/FOnDpz0Kg4ixq2Sh/SCiqhSGztWLad9PT0IZ+WwihKZOCEYNouFJFWSZXASQ6QvI/kXpLPkfx4VoFJe8WpabcjoSW9GVx89uCSP2/asAqlnvDaeW37VwC+74Fq2FJUaUfgPwLwsJmdDWA1gOfShyTtFnepeDsSWtKbwT1/enlJ3KPDQ/jBl1dj2VFL4112VGlJKx6AwPdA7XtSVIkPNSb5YQC7AXzEIr6IDjUupnVjO2IfqNvqU3zGpyq48Z5dgQttwiQ5CDjJeyDSLq041PgjAKYB/IzkagCTAK43s3c8F94IYCMArFixIsXlpFWS1LRbPSk3OjyEG+7Zleh74yxxr92EoqzIFCmaNCWUXgDnA/ixmQ0DeAfAZu+TzGyrmY2Y2cjg4KD3P0sBFHWSzjshWdNf7glcXQkAPWTTWr63bBQk7/dAJEyaBP4KgFfM7Knqn+/DQkKXNku7qMavpl3boS/PsxX94ir3EO/NW+he3/NmTWv5UbpcijpRqbMvpSZxCcXM/k7yZZKrzGwfgEsA/CW70CSKLBbV1C80qczMLtmhr1WL
dOLG1ezUGgDo4cIhwPWC+tPDSiMEItX1Wz0PEHTNdiyiEjcknsQEAJJrANwB4CgAfwPwNTN7I+j5msTMXtaTb0WfzAvb1yRou1kCeGHs0iWPpf05vYkUWBixt7o7peh/P9IarZjEhJntAtDwotI+WS+qaffZit5R7MVnD+KJvdOJRrVxNo1Ke5RcXocI6+xLqaeVmI7LegKynROafv3nd+3cH9qPPtBX9n0tYmERT9T+9LS93Xkl0qJOOEs+lMAdl/WimnauOowykeg9+ODWy89B2WdlpQHYPlnBVWuH2rLgZqDf/0YS9HhWtCpU6mkzK8dF2ekuzmRbO3fOizparX9eLY6b7t3d0IkyOzePJ/ZOJ6phx50MDJo6SjGlFIl2NpR6SuCOCEvCYYtqkiSqdu2cF/XsSm95YHR4CDcGLPKJelNIW8N+c9Z/d8Sgx7OknQ2lRiUUB8Tdq6ReVucxxhG1T3nThlW+5ZB6QeWBoJrvQH850rXT1rBVi5YiUAJ3QJokHCVRZbkwJM7NZnR4CB86JvhDYFgN23eRT4l4+93Dka6dNgGrFi1FoATugDSjxWaJKs3o3k/cm81MyEENT25eH1rm8XaRLDuqF3NHGuviftdOm4C1Q6EUgWrgDkhzKG6zfues+5nj3myCfjYCi1u5BvHWgs8MWOTjd+0sJgNVi5a8KYE7wC8Jl0vEO+8dxpmbHwpdANMsUWXdzxyUkIPa6zZtWOW7baxVY46TIOPe6JSAxXUqoTjA+3F9eX8ZMGBmdi7SApjR4SE8uXk9Xhi7tKEskfVkXNDp8W+/eziwDp7VVq6qS0u3UQJ3RH0S7vep9XpFneTMOumNDg9h2VGNH+zmjhhuune3bxIP2jY27k1EdWnpNiqhOCjJApggrVgYEtQLPW/m24Oedl+SeiqLSDdRAndQ0gUwQbJOemHx+U2QanWhSDJK4A7yG7H68Z7QnlbUJfnN4gvqClHCFolHCdwxtSQ6OzePEol5s8WvXk/snc70ulGX5IftVwJotaJIVpTAC8xvr+ztk5XFJDpvhr5yKXCkW5mZbdpLHVWzBTpBI/Osatsi0kgJvKD8Rrx379zf0HJXPxL3k9VxW0ETorWReNjIXLVtkdZQAi8ovxFvUONg2Eg8q1NigiYmS2ToSk7VtkVaJ1UfOMkXST5DchdJHXaZoTiLWGr9zmGv1YqT6/vKpcCRv474Emm9LBbyXGxma/wO3JTkgib6vGscazXl0eGhwAUxx/WVE21YVZ/0tzyyz/e0m6wW4YhIfCqhFFTQ4par1g4FHvob9D0kIm1YVT9pOtBfxtvvHl5c8VmZmcX2yYrvykZNVIrkI20CNwCPkjQA/2VmW71PILkRwEYAWLFiRcrLdY8kE4BB3xPl9BrvpOkbPtu8ahGOSLHQUhziR/JUMztA8iQAjwH4tpn9Iej5IyMjNjFRvFJ5nDMjXYxp3dgO3wnIoYG+xfMjg57jRQAvjF2aKA4RSYbkpF+ZOtUI3MwOVL8eJPkAgAsABCbwIkp7uG1WMYT1e6eNKcpeI2mPEmuFIt5YRYok8SQmyWUkj639HsBnADybVWDtkseZkfX8TsS5e+f+TGOKsktflMTcztp21icFiXSiNCPwkwE8QLL2Or8ws4cziaqNsj7QIK44/d5hMTUbrTbrxw46NGLZUb14c3au7SPgrE8KEulEiRO4mf0NwOoMY8lFmuPKshDnRhEUUxZloKJNRuZ9YxVxQde3EWa5F3USYWdC1o/Ew2LKarRapFWTed9YRVzQ9Sfy5H2KS9AKx3+9cEXkmDpxtKrj0USa6/oROJDvyDOL0kUnjlaLVtIRKaJUfeBxFbUPPGvtbn/z1sCBhdGqzoMU6Qwt6QPvJlGTch595RqtinQnJfAI4iTlvNrfsigDaeGMiFu6fhIzijiLfVydUNTCGRH3KIFHECcpB00cFn1CMe8VqSISnxJ4BHGSsqvtb65+chDpZkrgEcRJynn3lSeV9JND
2pN+RCQ5TWJGELfLo0grGqNKsiK1CDs5inQzJfCIXEzKcSRpRdSGUyL5UgKXRXFvUqqbi+RLNXBJzNWOG5FO4ewIvJ2LTrTAxV/eOzmKdDsnE3g7J880URdMS/hF8uVkAm/n5Jkm6sJ1+uSuSJE5WQNv5+SZJupEpKhSJ3CSJZJTJH+bRUBRtHPyTBN1IlJUWYzArwfwXAavE1k7l6u7ujReRDpfqgRO8jQAlwK4I5twomnncnVXl8aLSOdLdSIPyfsA3A7gWAD/ZmaX+TxnI4CNALBixYq1L730UuDrqV1PRKRR5ifykLwMwEEzmyT5yaDnmdlWAFuBhSPVgp7Xye16ujGJSCukKaGsA3A5yRcB/BLAepJ3JX2xTt2PWgcliEirJE7gZnazmZ1mZisBXA1gh5ldm/T1OrVdr1NvTCKSv8L0gXdqu16n3phEJH+ZJHAz+53fBGYcndqu16k3JhHJX2FG4J3artepNyYRyV+h9kLpxH01tOGTiLRKoRJ4p+rEG5OI5K8wJRQREYlHCVxExFFK4CIijlICFxFxlBK4iIijlMBFRBylBC4i4iglcBERRymBi4g4SglcRMRRHbeUXqffiEi36KgE3snHsomIeHVUCUWn34hIN+moBK7Tb0Skm3RUAtfpNyLSTRIncJLHkPwjyd0k95C8LcvAktDpNyLSTdJMYr4HYL2ZvU2yDOB/Sf6Pme3MKLbYdPqNiHSTxAnczAzA29U/lqu/LIug0tDpNyLSLVLVwEmWSO4CcBDAY2b2lM9zNpKcIDkxPT2d5nIiIlInVQI3s3kzWwPgNAAXkDzX5zlbzWzEzEYGBwfTXE5EROpk0oViZjMAfgfgs1m8noiINJemC2WQ5ED1930APgVgb0ZxiYhIE2m6UE4BcCfJEhZuBPea2W+zCUtERJpJ04XyZwDDGcYiIiIxdNRKTBGRbqIELiLiKCVwERFHKYGLiDhKCVxExFFK4CIijlICFxFxlBK4iIijuLArbJsuRk4DeCnjlz0RwD8yfs1WUJzZcSFGQHFmrZvjPMPMGnYDbGsCbwWSE2Y2kncczSjO7LgQI6A4s6Y4G6mEIiLiKCVwERFHdUIC35p3ABEpzuy4ECOgOLOmOD2cr4GLiHSrThiBi4h0JSVwERFHOZnASR5D8o8kd5PcQ/K2vGMKQ7JEcopkYU8sIvkiyWdI7iI5kXc8QUgOkLyP5F6Sz5H8eN4xeZFcVX0fa7/eInlD3nH5IXlj9d/QsyS3kTwm75i8SF5fjW9P0d5Hkj8leZDks3WPHU/yMZLPV78ub9X1nUzgAN4DsN7MVgNYA+CzJC/MN6RQ1wN4Lu8gIrjYzNYUvNf2RwAeNrOzAaxGAd9XM9tXfR/XAFgL4BCAB/KNqhHJIQDfATBiZucCKAG4Ot+oliJ5LoBvALgAC3/fl5E8K9+olvg5Gg9z3wzgcTM7C8Dj1T+3hJMJ3Ba8Xf1jufqrkLOxJE8DcCmAO/KOxXUkPwzgIgA/AQAze9/MZnINqrlLAPzVzLJegZyVXgB9JHsB9AM4kHM8Xh8FsNPMDpnZYQC/B/DFnGNaZGZ/APC65+ErANxZ/f2dAEZbdX0nEziwWJbYBeAggMfM7KmcQwrynwD+HcCRnONoxgA8SnKS5Ma8gwnwEQDTAH5WLUndQXJZ3kE1cTWAbXkH4cfMKgC+D2A/gFcBvGlmj+YbVYNnAVxE8gSS/QA+D+D0nGNq5mQzexUAql9PatWFnE3gZjZf/Yh6GoALqh+1CoXkZQAOmtlk3rFEsM7MzgfwOQDfInlR3gH56AVwPoAfm9kwgHfQwo+naZE8CsDlAH6Vdyx+qrXZKwCcCeBUAMtIXptvVEuZ2XMAvgfgMQAPA9gN4HCuQRWIswm8pvoR+ndorEMVwToAl5N8EcAvAawneVe+IfkzswPVrwexUK+9IN+IfL0C4JW6T1v3YSGhF9XnADxtZq/lHUiATwF4wcymzWwOwP0APpFzTA3M7Cdm
dr6ZXYSFcsXzecfUxGskTwGA6teDrbqQkwmc5CDJgerv+7DwP+LeXIPyYWY3m9lpZrYSCx+ld5hZoUY4AEByGclja78H8BksfHQtFDP7O4CXSa6qPnQJgL/kGFIz16Cg5ZOq/QAuJNlPklh4Pws3KUzypOrXFQCuRLHfUwB4EMB11d9fB+DXrbpQb6teuMVOAXAnyRIWbkL3mllhW/QccDKABxb+DaMXwC/M7OF8Qwr0bQB3V8sTfwPwtZzj8VWt134awDfzjiWImT1F8j4AT2OhLDGFYi5X307yBABzAL5lZm/kHVANyW0APgngRJKvALgFwBiAe0l+HQs3yS+37PpaSi8i4iYnSygiIqIELiLiLCVwERFHKYGLiDhKCVxExFFK4CIijlICFxFx1P8D9M3QWa37UMwAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "y_meas_test = transformers[0].untransform(test_dataset.y)\n",
    "y_pred_test, y_pred_test_stds = predict_with_error(model, test_dataset.X, transformers[0])\n",
    "\n",
    "plt.xlim([2.5, 10.5])\n",
    "plt.ylim([2.5, 10.5])\n",
    "plt.scatter(y_meas_test, y_pred_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jNogRsV4p5Xa"
   },
   "source": [
    "We can also write a function to calculate the fraction of predicted values that fall within the predicted error range. This is done by counting how many samples have an absolute error (the difference between the measured and predicted value) smaller than the standard deviation predicted for that sample. For Gaussian-distributed errors, one standard deviation corresponds to a 68% confidence interval."
   ]
  },
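  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a brief aside on where the 68% figure comes from, here is a short worked formula. It assumes (this is an assumption of the sketch, not something checked in this notebook) that the error on each sample is Gaussian with mean zero and standard deviation equal to the predicted one. The symbols $y_i$ (measured value), $\\mu_i$ (predicted mean), $\\sigma_i$ (predicted standard deviation) and $k$ (number of standard deviations) are notation introduced here for illustration.\n",
    "\n",
    "$$P\\left(|y_i - \\mu_i| < k\\,\\sigma_i\\right) = \\mathrm{erf}\\left(\\frac{k}{\\sqrt{2}}\\right)$$\n",
    "\n",
    "For $k = 1$ this gives approximately 0.683, and for $k = 2$ approximately 0.954. Under this assumption, a well-calibrated model should have roughly 68% of its samples within one predicted standard deviation. A noticeably larger fraction suggests the predicted uncertainties are too wide, and a noticeably smaller fraction suggests they are too narrow."
   ]
  },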
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "Xihho3Ndp5gL"
   },
   "outputs": [],
   "source": [
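    "def percent_within_std(y_meas, y_pred, y_std):\n",
    "    # Returns the fraction (0 to 1) of samples whose absolute error is below the predicted std\n",
    "    assert len(y_meas) == len(y_pred) == len(y_std), 'y_meas, y_pred, and y_std must all have the same length'\n",
    "\n",
    "    count_within_error = 0\n",
    "    for i in range(len(y_meas)):\n",
    "        # each y_meas[i] is itself indexable, so take its first element as the measured value\n",
    "        if abs(y_meas[i][0]-y_pred[i]) < y_std[i]:\n",
    "            count_within_error += 1\n",
    "\n",
    "    return count_within_error/len(y_meas)"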
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the training set, >90% of the samples are within one standard deviation. In comparison, about 74% of the samples are within one standard deviation for the test set. One standard deviation corresponds to a 68% confidence interval, so for the test set the predicted uncertainty is close to the observed error. On the training set, however, the model overpredicts uncertainty."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9355371900826446"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "percent_within_std(y_meas_train, y_pred_train, y_pred_train_stds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "YpgKQswIRjvN",
    "outputId": "af83cdfd-7f1a-45ab-a244-c26a2b40c4a0"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7368421052631579"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "percent_within_std(y_meas_test, y_pred_test, y_pred_test_stds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xsa03OFu8taf"
   },
   "source": [
    "We can also take a look at the distribution of the predicted standard deviations for the test set. The distribution is very roughly Gaussian."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 265
    },
    "id": "na23a9AU8s8d",
    "outputId": "6d5ccd5a-74ee-4c6e-fc4c-02ac6a1c356e"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAANwklEQVR4nO3df6xf9V3H8edrFLLpftDa266BsaumbuAibF4Zij+2VZSBs2jEgDoaUtNoNGGJU+v+2GKMSRcTo0aNNkjWRUVJBrbChmvuRNQB47KVX5ZZZBUJDS1sjjHNtPD2j+9hu5be+z33x/d774c+H8nN+fE933tePbl99Xw/95zTVBWSpHa9YqUDSJKWxiKXpMZZ5JLUOItckhpnkUtS49aMc2fr16+vycnJce5Skpp33333PV1VE3O9PtYin5ycZGZmZpy7lKTmJfn3+V53aEWSGmeRS1LjLHJJapxFLkmNs8glqXEWuSQ1rtflh0kOA18BngeOV9VUknXAXwOTwGHgp6vqS6OJKUmay0LOyN9ZVRdU1VS3vBOYrqrNwHS3LEkas6UMrWwF9nTze4ArlpxGkrRgfe/sLOCTSQr406raDWysqiMAVXUkyYaTvTHJDmAHwDnnnLMMkU8dkztvW5H9Ht51+YrsV9Li9C3yi6vqya6s9yd5pO8OutLfDTA1NeV/RyRJy6zX0EpVPdlNjwK3ABcCTyXZBNBNj44qpCRpbkOLPMk3J3nNi/PAjwAPAfuAbd1m24C9owopSZpbn6GVjcAtSV7c/i+r6vYk9wI3JdkOPA5cObqYkqS5DC3yqnoMOP8k658BtowilCSpP+/slKTGWeSS1DiLXJIaZ5FLUuMscklqnEUuSY2zyCWpcRa5JDXOIpekxlnkktQ4i1ySGmeRS1LjLHJJapxFLkmNs8glqXEWuSQ1ziKXpMZZ5JLUOItckhpnkUtS4yxySWqcRS5JjbPIJalxFrkkNc4il6TGWeSS1DiLXJIaZ5FLUuMscklqnEUuSY2zyCWpcRa5JDWud5EnOS3J55Lc2i2vS7I/yaFuunZ0MSVJc1nIGfl1wMFZyzuB6araDEx3y5KkMetV5EnOBi4Hrp+1eiuwp5vfA1yxrMkkSb30PSP/PeDXgBdmrdtYVUcAuumGk70xyY4kM0lmjh07tpSskqSTGFrkSX4MOFpV9y1mB1W1u6qmqmpqYmJiMd9CkjSPNT22uRj48SSXAa8EXpvkz4GnkmyqqiNJNgFHRxlUknRyQ8/Iq+o3qursqpoErgI+VVU/B+wDtnWbbQP2jiylJGlOS7mOfBdwSZJDwCXdsiRpzPoMrXxdVd0B3NHNPwNsWf5IkqSF8M5OSWqcRS5JjbPIJalxFrkkNc4il6TGWeSS1DiLXJIaZ5FLUuMscklqnEUuSY2zyCWpcRa5JDXOIpekxlnkktQ4i1ySGmeRS1LjLHJJapxFLkmNs8glqXEWuSQ1ziKXpMZZ5JLUOItckhq3ZqUDaPWZ3Hnbiu378K7LV2zfUqs8I5ekxlnkktQ4i1ySGmeRS1LjLHJJapxFLkmNs8glqXEWuSQ1bmiRJ3llks8kuT/Jw0l+s1u/Lsn+JIe66drRx5UknajPGfnXgHdV1fnABcClSS4CdgLTVbUZmO6WJUljNrTIa+C5bvH07quArcCebv0e4IpRBJQkza/XGHmS05IcAI4C+6vqHmBjVR0B6KYbRpZSkjSnXkVeVc9X1QXA2cCFSd7SdwdJdiSZSTJz7NixRcaUJM1lQVetVNV/AncAlwJPJdkE0E2PzvGe3VU1VVVTExMTS0srSXqJPletTCQ5s5t/FfDDwCPAPmBbt9k2YO+IMkqS5tHneeSbgD1JTmNQ/DdV1a1J7gJuSrIdeBy4coQ5JUlzGFrkVfUA8NaTrH8G2DKKUJKk/ryzU5IaZ5FLUuMscklqnEUuSY2zyCWpcRa5JDXOIpekxlnkktQ4i1ySGmeR
S1LjLHJJalyfh2ZJYzO587YV2e/hXZevyH6l5eAZuSQ1ziKXpMZZ5JLUOMfIe1ipcVtJ6sMzcklqnEUuSY2zyCWpcRa5JDXOIpekxlnkktQ4i1ySGmeRS1LjLHJJapxFLkmNs8glqXEWuSQ1ziKXpMZZ5JLUOItckhpnkUtS4yxySWrc0CJP8oYkf5/kYJKHk1zXrV+XZH+SQ9107ejjSpJO1OeM/DjwK1V1LnAR8EtJzgN2AtNVtRmY7pYlSWM2tMir6khVfbab/wpwEDgL2Ars6TbbA1wxooySpHksaIw8ySTwVuAeYGNVHYFB2QMb5njPjiQzSWaOHTu2xLiSpBP1LvIkrwY+Bryvqp7t+76q2l1VU1U1NTExsZiMkqR59CryJKczKPG/qKqbu9VPJdnUvb4JODqaiJKk+fS5aiXAnwEHq+p3Z720D9jWzW8D9i5/PEnSMGt6bHMx8F7gwSQHunUfAHYBNyXZDjwOXDmShJKkeQ0t8qr6JyBzvLxleeNIkhbKOzslqXEWuSQ1ziKXpMZZ5JLUOItckhpnkUtS4yxySWqcRS5JjbPIJalxFrkkNc4il6TGWeSS1DiLXJIaZ5FLUuMscklqnEUuSY2zyCWpcRa5JDXOIpekxlnkktQ4i1ySGmeRS1LjLHJJapxFLkmNs8glqXEWuSQ1ziKXpMZZ5JLUOItckhpnkUtS4yxySWqcRS5JjRta5EluSHI0yUOz1q1Lsj/JoW66drQxJUlz6XNG/hHg0hPW7QSmq2ozMN0tS5JWwNAir6o7gS+esHorsKeb3wNcsbyxJEl9LXaMfGNVHQHophvm2jDJjiQzSWaOHTu2yN1JkuYy8l92VtXuqpqqqqmJiYlR706STjmLLfKnkmwC6KZHly+SJGkhFlvk+4Bt3fw2YO/yxJEkLVSfyw9vBO4C3pTkiSTbgV3AJUkOAZd0y5KkFbBm2AZVdfUcL21Z5iySpEXwzk5JapxFLkmNs8glqXEWuSQ1ziKXpMZZ5JLUOItckhpnkUtS4yxySWqcRS5JjbPIJalxFrkkNc4il6TGWeSS1DiLXJIaN/R55KvF5M7bVjqCXsZOxZ+vw7suX+kIWiaekUtS4yxySWqcRS5JjbPIJalxFrkkNc4il6TGWeSS1DiLXJIaZ5FLUuMscklqnEUuSY2zyCWpcc08NEvS8lrJB4X5wK7l5Rm5JDXOIpekxlnkktQ4x8glnTJerr8XWNIZeZJLk3w+yaNJdi5XKElSf4su8iSnAX8EvBs4D7g6yXnLFUyS1M9SzsgvBB6tqseq6n+AvwK2Lk8sSVJfSxkjPwv4j1nLTwBvP3GjJDuAHd3ic0k+v4R9jsJ64OmVDtFDCzlbyAjmXE6LypgPjyDJ/Fb8WPb8M8+V843zvWkpRZ6TrKuXrKjaDexewn5GKslMVU2tdI5hWsjZQkYw53JqISO8/HMuZWjlCeANs5bPBp5cwveTJC3CUor8XmBzkm9NcgZwFbBveWJJkvpa9NBKVR1P8svA3wGnATdU1cPLlmx8Vu2wzwlayNlCRjDncmohI7zMc6bqJcPakqSGeIu+JDXOIpekxp0yRT7scQJJ3pHky0kOdF8fXG0ZZ+U8kOThJP8w7oxdhmHH8ldnHceHkjyfZN0qzPm6JH+b5P7ueF67CjOuTXJLkgeSfCbJW1Yg4w1JjiZ5aI7Xk+QPuj/DA0neNu6MXY5hOd+c5K4kX0vy/nHnm5VjWM6f7Y7jA0k+neT8od+0ql72Xwx+GftvwLcBZwD3A+edsM07gFtXecYzgX8BzumWN6zGnCds/x7gU6sxJ/AB4MPd/ATwReCMVZbxd4APdfNvBqZX4Fj+IPA24KE5Xr8M+ASDe0suAu4Zd8aeOTcA3wP8NvD+lcjYM+f3AWu7+Xf3OZ6nyhl5C48T6JPxZ4Cbq+pxgKo6OuaMsPBjeTVw41iS/X99chbwmiQBXs2gyI+vsoznAdMAVfUIMJlk4xgzUlV3Mjg2
c9kKfLQG7gbOTLJpPOm+YVjOqjpaVfcC/zu+VCfNMSznp6vqS93i3Qzu0ZnXqVLkJ3ucwFkn2e57u4/Zn0jyneOJ9nV9Mn4HsDbJHUnuS3LN2NJ9Q99jSZJvAi4FPjaGXCfqk/MPgXMZ3Mj2IHBdVb0wnnhAv4z3Az8JkORCBrdqD/2LPWa9fya0YNsZfNqZ16nyPPI+jxP4LPDGqnouyWXA3wCbRx1slj4Z1wDfDWwBXgXcleTuqvrXUYebpdejGTrvAf65quY7mxuVPjl/FDgAvAv4dmB/kn+sqmdHnO1FfTLuAn4/yQEG/9h8jvF+auhjIT8T6inJOxkU+fcP2/ZUOSMf+jiBqnq2qp7r5j8OnJ5k/fgi9nrkwRPA7VX11ap6GrgTGP6LkOW1kEczXMXKDKtAv5zXMhiqqqp6FPgCg3Hocen7c3ltVV0AXMNgLP8LY0vYj4/rWGZJvgu4HthaVc8M2/5UKfKhjxNI8vpurPTFj7CvAIYewHFmBPYCP5BkTTds8Xbg4Bgz9s1JktcBP8Qg80rok/NxBp9u6Mad3wQ8tpoyJjmzew3g54E7x/iJoa99wDXd1SsXAV+uqiMrHapVSc4Bbgbe2/fT9ikxtFJzPE4gyS90r/8J8FPALyY5Dvw3cFV1vzZeLRmr6mCS24EHgBeA66vqpJcwrWTObtOfAD5ZVV8dZ74F5vwt4CNJHmQwPPDr3Sed1ZTxXOCjSZ5ncMXS9nHle1GSGxlc1bU+yRPAh4DTZ2X8OIMrVx4F/ovBJ52xG5YzyeuBGeC1wAtJ3sfgKqGx/sPY43h+EPgW4I+7c8vjNeSJiN6iL0mNO1WGViTpZcsil6TGWeSS1DiLXJIaZ5FLUuMscklqnEUuSY37P6u7XdelWU/yAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist(y_pred_test_stds)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For now, this is the end of our tutorial. We plan to follow up soon with a deeper dive into uncertainty estimation and in particular, calibrated uncertainty estimation. We will see you then!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_JEhXwfSUMYD"
   },
   "source": [
    "## Appendix: Hyperparameter Optimization\n",
    "\n",
    "As hyperparameter optimization is outside the scope of this tutorial, we will not explain how to use Optuna to tune hyperparameters. The code is still included for the sake of completeness."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ItUhfWMBAYMF"
   },
   "outputs": [],
   "source": [
    "%pip install optuna\n",
    "import optuna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9OWKLnv4iGWp"
   },
   "outputs": [],
   "source": [
    "def get_model(trial):\n",
    "    output_variance = trial.suggest_float('output_variance', 0.1, 10, log=True)\n",
    "    length_scale = trial.suggest_float('length_scale', 1e-5, 1e5, log=True)\n",
    "    noise_level = trial.suggest_float('noise_level', 1e-5, 1e5, log=True)\n",
    "\n",
    "    params = {\n",
    "        'kernel': output_variance**2 * RBF(length_scale=length_scale, length_scale_bounds='fixed') + WhiteKernel(noise_level=noise_level, noise_level_bounds='fixed'),\n",
    "        'alpha': trial.suggest_float('alpha', 1e-12, 1e-5, log=True),\n",
    "    }\n",
    "\n",
    "    sklearn_gpr = GaussianProcessRegressor(**params)\n",
    "    return dc.models.SklearnModel(sklearn_gpr)\n",
    "\n",
    "def objective(trial):\n",
    "    model = get_model(trial)\n",
    "    model.fit(train_dataset)\n",
    "    \n",
    "    metric = dc.metrics.Metric(dc.metrics.mean_squared_error)\n",
    "    return model.evaluate(valid_dataset, [metric])['mean_squared_error']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jA74yOKWexnW"
   },
   "outputs": [],
   "source": [
    "study = optuna.create_study(direction='minimize')\n",
    "study.optimize(objective, n_trials=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BABAYa2NpuiN",
    "outputId": "79c9e271-4524-461a-fbc1-ef9ce92260e4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'output_variance': 0.38974570882583015, 'length_scale': 5.375387643239208, 'noise_level': 0.0016265333497286342, 'alpha': 1.1273318360324618e-11}\n"
     ]
    }
   ],
   "source": [
    "print(study.best_params)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Uncertainty_Estimation_using_Gaussian_Processes.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "mol",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
